package com.yeskery.nut.extend.auth;

import com.yeskery.nut.core.HttpHeader;
import com.yeskery.nut.core.RequestHandler;
import com.yeskery.nut.core.ResponseCode;
import com.yeskery.nut.core.ResponseNutException;
import com.yeskery.nut.util.StringUtils;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Base64;
import java.util.function.BiFunction;

/**
 * HTTP Basic authentication interceptor.
 *
 * <p>Reads the {@code Authorization} request header, decodes the Basic credentials and hands the
 * user name and password to an authentication function. When authentication fails a
 * {@code WWW-Authenticate} challenge header is added to the response before the configured
 * unauthorized handler is invoked.
 *
 * @author sprout
 * @version 1.0
 * 2024-06-16 01:13
 */
public class BasicAuthInterceptor extends AuthPlugin {

    /** Prefix of a Basic authorization header value (scheme name followed by one space). */
    private static final String BASIC_AUTH_PREFIX = "Basic ";

    /** Realm description used when none is supplied. */
    private static final String DEFAULT_REALM = "Nut Basic Authorization";

    /** Number of parts the decoded credentials are split into: user name and password. */
    private static final int AUTH_PARAMETER_SIZE = 2;

    /**
     * Builds the Basic authentication plugin.
     * @param authenticateFunction function receiving the user name and the password; authentication
     *                             succeeds only when it returns {@code true}
     * @param realm realm description sent in the challenge; {@code null} falls back to the default realm
     * @param charset charset used to decode the credentials carried in the request header
     * @param unauthorizedHandler handler invoked when authentication fails
     */
    public BasicAuthInterceptor(BiFunction<String, String, Boolean> authenticateFunction, String realm, Charset charset, RequestHandler unauthorizedHandler) {
        super((request, controller) -> authenticate(request.getHeader(HttpHeader.AUTHORIZATION), charset, authenticateFunction),
                (request, response, execution) -> {
                    response.addHeader(HttpHeader.WWW_AUTHENTICATE, buildChallenge(realm));
                    unauthorizedHandler.handle(request, response, execution);
                });
    }

    /**
     * Builds the Basic authentication plugin; failed authentication throws a
     * {@link ResponseNutException} carrying {@link ResponseCode#UNAUTHORIZED}.
     * @param authenticateFunction function receiving the user name and the password
     * @param realm realm description
     * @param charset charset used to decode the credentials carried in the request header
     */
    public BasicAuthInterceptor(BiFunction<String, String, Boolean> authenticateFunction, String realm, Charset charset) {
        this(authenticateFunction, realm, charset, (request, response, execution) -> {
            throw new ResponseNutException(ResponseCode.UNAUTHORIZED);
        });
    }

    /**
     * Builds the Basic authentication plugin decoding credentials as UTF-8.
     * @param authenticateFunction function receiving the user name and the password
     * @param realm realm description
     */
    public BasicAuthInterceptor(BiFunction<String, String, Boolean> authenticateFunction, String realm) {
        this(authenticateFunction, realm, StandardCharsets.UTF_8);
    }

    /**
     * Builds the Basic authentication plugin with the default realm and UTF-8 credentials.
     * @param authenticateFunction function receiving the user name and the password
     */
    public BasicAuthInterceptor(BiFunction<String, String, Boolean> authenticateFunction) {
        this(authenticateFunction, DEFAULT_REALM);
    }

    /**
     * Builds the Basic authentication plugin for a single fixed user name / password pair.
     * @param username expected user name
     * @param password expected password
     * @param realm realm description
     * @param charset charset used to decode the credentials carried in the request header
     */
    public BasicAuthInterceptor(String username, String password, String realm, Charset charset) {
        this((u, p) -> {
            // Evaluate both comparisons unconditionally so a wrong user name is not
            // distinguishable from a wrong password by short-circuit timing.
            boolean usernameMatches = secureEquals(username, u);
            boolean passwordMatches = secureEquals(password, p);
            return usernameMatches & passwordMatches;
        }, realm, charset);
    }

    /**
     * Builds the Basic authentication plugin for a fixed user name / password pair, UTF-8 credentials.
     * @param username expected user name
     * @param password expected password
     * @param realm realm description
     */
    public BasicAuthInterceptor(String username, String password, String realm) {
        this(username, password, realm, StandardCharsets.UTF_8);
    }

    /**
     * Builds the Basic authentication plugin for a fixed user name / password pair with the
     * default realm and UTF-8 credentials.
     * @param username expected user name
     * @param password expected password
     */
    public BasicAuthInterceptor(String username, String password) {
        this(username, password, DEFAULT_REALM);
    }

    /**
     * Parses an {@code Authorization} header value and runs the authentication function on it.
     * @param headerValue raw header value, may be {@code null} or empty
     * @param charset charset used to turn the decoded bytes into text
     * @param authenticateFunction function receiving the user name and the password
     * @return {@code true} only when the header carries well-formed Basic credentials and the
     *         function returns {@code true} for them
     */
    private static boolean authenticate(String headerValue, Charset charset, BiFunction<String, String, Boolean> authenticateFunction) {
        int prefixLength = BASIC_AUTH_PREFIX.length();
        // The scheme name is matched ignoring case ("basic", "BASIC", ...).
        if (StringUtils.isEmpty(headerValue)
                || headerValue.length() <= prefixLength
                || !headerValue.regionMatches(true, 0, BASIC_AUTH_PREFIX, 0, prefixLength)) {
            return false;
        }
        String token = headerValue.substring(prefixLength).trim();
        if (token.isEmpty()) {
            return false;
        }
        byte[] decoded;
        try {
            decoded = Base64.getDecoder().decode(token);
        } catch (IllegalArgumentException ignored) {
            // Malformed Base64 comes from the client: treat it as a failed authentication
            // instead of letting the exception escape.
            return false;
        }
        // Split at the first colon only: the password part may itself contain colons,
        // and a positive limit keeps an empty trailing password ("user:").
        String[] credentials = new String(decoded, charset).split(":", AUTH_PARAMETER_SIZE);
        if (credentials.length != AUTH_PARAMETER_SIZE) {
            return false;
        }
        // Boolean.TRUE.equals guards against a function that returns null.
        return Boolean.TRUE.equals(authenticateFunction.apply(credentials[0], credentials[1]));
    }

    /**
     * Builds the {@code WWW-Authenticate} challenge value for the given realm.
     * @param realm realm description; {@code null} falls back to the default realm
     * @return challenge value with the realm emitted as a quoted string
     */
    private static String buildChallenge(String realm) {
        String effectiveRealm = realm == null ? DEFAULT_REALM : realm;
        // Backslash and double quote must be escaped inside a quoted string,
        // otherwise the header value would be malformed.
        String escapedRealm = effectiveRealm.replace("\\", "\\\\").replace("\"", "\\\"");
        return BASIC_AUTH_PREFIX + "realm=\"" + escapedRealm + "\"";
    }

    /**
     * Compares an expected secret with a supplied value through {@link MessageDigest#isEqual}
     * on their UTF-8 bytes rather than {@link String#equals}.
     * @param expected configured value
     * @param actual value supplied by the client, may be {@code null}
     * @return {@code true} when both encode to the same bytes
     */
    private static boolean secureEquals(String expected, String actual) {
        if (actual == null) {
            return false;
        }
        return MessageDigest.isEqual(expected.getBytes(StandardCharsets.UTF_8), actual.getBytes(StandardCharsets.UTF_8));
    }
}
